{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "6.8. 长短期记忆(LSTM).ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/chongzicbo/Dive-into-Deep-Learning-tf.keras/blob/master/6.8.%20%E9%95%BF%E7%9F%AD%E6%9C%9F%E8%AE%B0%E5%BF%86(LSTM).ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LwQk3RY81Jtc",
        "colab_type": "text"
      },
      "source": [
        "##6.8. 长短期记忆（LSTM）\n",
        "本节将介绍另一种常用的门控循环神经网络：长短期记忆（long short-term memory，LSTM）[1]。它比门控循环单元的结构稍微复杂一点。\n",
        "\n",
        "###6.8.1. 长短期记忆\n",
        "LSTM 中引入了3个门，即输入门（input gate）、遗忘门（forget gate）和输出门（output gate），以及与隐藏状态形状相同的记忆细胞（某些文献把记忆细胞当成一种特殊的隐藏状态），从而记录额外的信息。\n",
        "\n",
        "####6.8.1.1. 输入门、遗忘门和输出门\n",
        "与门控循环单元中的重置门和更新门一样，如图6.7所示，长短期记忆的门的输入均为当前时间步输入 $X_t$ 与上一时间步隐藏状态 $H_{t-1}$ ，输出由激活函数为sigmoid函数的全连接层计算得到。如此一来，这3个门元素的值域均为 [0,1] 。\n",
        "\n",
        "<img src=\"https://zh.gluon.ai/_images/lstm_0.svg\" width=\"500\"/>\n",
        "\n",
        "<center>图 6.7 长短期记忆中输入门、遗忘门和输出门的计算</center>\n",
        "\n",
        "具体来说，假设隐藏单元个数为 $h$ ，给定时间步 $t$ 的小批量输入$\\boldsymbol{X}_t \\in \\mathbb{R}^{n \\times d}$（样本数为 $n$ ，输入个数为 $d$ ）和上一时间步隐藏状态$\\boldsymbol{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$。 时间步 $t$ 的输入门$\\boldsymbol{I}_t \\in \\mathbb{R}^{n \\times h}$、遗忘门$\\boldsymbol{F}_t \\in \\mathbb{R}^{n \\times h}$和输出门$ \\boldsymbol{O}_t \\in \\mathbb{R}^{n \\times h}$分别计算如下：\n",
        "$$\n",
        "\\begin{split}\\begin{aligned}\n",
        "\\boldsymbol{I}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xi} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hi} + \\boldsymbol{b}_i),\\\\\n",
        "\\boldsymbol{F}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xf} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hf} + \\boldsymbol{b}_f),\\\\\n",
        "\\boldsymbol{O}_t &= \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xo} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{ho} + \\boldsymbol{b}_o),\n",
        "\\end{aligned}\\end{split}\n",
        "$$\n",
        "其中的 $\\boldsymbol{W}_{xi}, \\boldsymbol{W}_{xf}, \\boldsymbol{W}_{xo} \\in \\mathbb{R}^{d \\times h}$和$\\boldsymbol{W}_{hi}, \\boldsymbol{W}_{hf}, \\boldsymbol{W}_{ho} \\in \\mathbb{R}^{h \\times h}$是权重参数，$\\boldsymbol{b}_i, \\boldsymbol{b}_f, \\boldsymbol{b}_o \\in \\mathbb{R}^{1 \\times h}$是偏差参数。\n",
        "\n",
        "####6.8.1.2. 候选记忆细胞\n",
        "接下来，长短期记忆需要计算候选记忆细胞$\\tilde{\\boldsymbol{C}}_t$。它的计算与上面介绍的3个门类似，但使用了值域在 $[-1,1]$ 的tanh函数作为激活函数，如图6.8所示。\n",
        "\n",
        "<img src=\"https://zh.gluon.ai/_images/lstm_1.svg\" width=\"500\"/>\n",
        "\n",
        "<center>图 6.8 长短期记忆中候选记忆细胞的计算</center>\n",
        "\n",
        "具体来说，时间步 $t$ 的候选记忆细胞$\\tilde{\\boldsymbol{C}}_t \\in \\mathbb{R}^{n \\times h}$的计算为\n",
        "$$\n",
        "\\tilde{\\boldsymbol{C}}_t = \\text{tanh}(\\boldsymbol{X}_t \\boldsymbol{W}_{xc} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hc} + \\boldsymbol{b}_c),\n",
        "$$\n",
        "其中$\\boldsymbol{W}_{xc} \\in \\mathbb{R}^{d \\times h}$和$\\boldsymbol{W}_{hc} \\in \\mathbb{R}^{h \\times h}$是权重参数，$\\boldsymbol{b}_c \\in \\mathbb{R}^{1 \\times h}$是偏差参数。\n",
        "\n",
        "####6.8.1.3. 记忆细胞\n",
        "我们可以通过元素值域在 $[0,1]$ 的输入门、遗忘门和输出门来控制隐藏状态中信息的流动，这一般也是通过使用按元素乘法（符号为$\\odot$)来实现的。当前时间步记忆细胞 $\\boldsymbol{C}_t \\in \\mathbb{R}^{n \\times h}$的计算组合了上一时间步记忆细胞和当前时间步候选记忆细胞的信息，并通过遗忘门和输入门来控制信息的流动：\n",
        "$$\n",
        "\\boldsymbol{C}_t = \\boldsymbol{F}_t \\odot \\boldsymbol{C}_{t-1} + \\boldsymbol{I}_t \\odot \\tilde{\\boldsymbol{C}}_t.\n",
        "$$\n",
        "如图6.9所示，遗忘门控制上一时间步的记忆细胞$\\boldsymbol{C}_{t-1}$中的信息是否传递到当前时间步，而输入门则控制当前时间步的输入 $X_t$ 通过候选记忆细胞$\\tilde{\\boldsymbol{C}}_t$如何流入当前时间步的记忆细胞。如果遗忘门一直近似1且输入门一直近似0，过去的记忆细胞将一直通过时间保存并传递至当前时间步。这个设计可以应对循环神经网络中的梯度衰减问题，并更好地捕捉时间序列中时间步距离较大的依赖关系。\n",
        "\n",
        "<img src=\"https://zh.gluon.ai/_images/lstm_2.svg\" width=\"500\"/>\n",
        "\n",
        "<center>图 6.9 长短期记忆中记忆细胞的计算。这里的 $\\odot$是按元素乘法</center>\n",
        "\n",
        "####6.8.1.4. 隐藏状态\n",
        "有了记忆细胞以后，接下来我们还可以通过输出门来控制从记忆细胞到隐藏状态$\\boldsymbol{H}_t \\in \\mathbb{R}^{n \\times h}$ 的信息的流动：\n",
        "$$\n",
        "\\boldsymbol{H}_t = \\boldsymbol{O}_t \\odot \\text{tanh}(\\boldsymbol{C}_t).\n",
        "$$\n",
        "这里的tanh函数确保隐藏状态元素值在-1到1之间。需要注意的是，当输出门近似1时，记忆细胞信息将传递到隐藏状态供输出层使用；当输出门近似0时，记忆细胞信息只自己保留。图6.10展示了长短期记忆中隐藏状态的计算。\n",
        "\n",
        "<img src=\"https://zh.gluon.ai/_images/lstm_3.svg\" width=\"500\"/>\n",
        "\n",
        "<center>图 6.10 长短期记忆中隐藏状态的计算。这里的 ⊙ 是按元素乘法</center>\n",
        "\n",
        "###6.8.2. 读取数据集\n",
        "下面我们开始实现并展示长短期记忆。和前几节中的实验一样，这里依然使用周杰伦歌词数据集来训练模型作词。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8-MMjILo1ieu",
        "colab_type": "code",
        "outputId": "d7af8b50-c802-4455-e64f-2cd99753e428",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 62
        }
      },
      "source": [
        "%matplotlib inline\n",
        "import math\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "from IPython import display\n",
        "import matplotlib.pyplot as plt\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import losses\n",
        "from tensorflow.data import Dataset\n",
        "import time\n",
        "import random\n",
        "import zipfile"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<p style=\"color: red;\">\n",
              "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
              "We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
              "or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
              "<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Xu1ZUk6p1xwJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tf.enable_eager_execution()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wEW9Ai6L1z2l",
        "colab_type": "code",
        "outputId": "a1c88d86-d9c2-4493-8707-258b6ce53126",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        }
      },
      "source": [
        "def load_data_jay_lyrics():\n",
        "  from google.colab import drive\n",
        "  drive.mount('/content/drive')\n",
        "  with zipfile.ZipFile('/content/drive/My Drive/data/d2l-zh-tensoflow/jaychou_lyrics.txt.zip')as zin:\n",
        "    with zin.open('jaychou_lyrics.txt') as f:\n",
        "      corpus_chars=f.read().decode('utf-8')\n",
        "  corpus_chars=corpus_chars.replace('\\n',' ').replace('\\r',' ')\n",
        "  corpus_chars=corpus_chars[0:10000]\n",
        "  idx_to_char=list(set(corpus_chars))\n",
        "  char_to_idx=dict([(char,i) for i,char in enumerate(idx_to_char)])\n",
        "  vocab_size=len(char_to_idx)\n",
        "  corpus_indices=[char_to_idx[char] for char in corpus_chars]\n",
        "  return corpus_indices,char_to_idx,idx_to_char,vocab_size\n",
        "\n",
        "(corpus_indices,char_to_idx,idx_to_char,vocab_size)=load_data_jay_lyrics() "
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
            "\n",
            "Enter your authorization code:\n",
            "··········\n",
            "Mounted at /content/drive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ohRSeSzp4nQ6",
        "colab_type": "text"
      },
      "source": [
        "###6.8.3. 从零开始实现\n",
        "我们先介绍如何从零开始实现长短期记忆。\n",
        "\n",
        "####6.8.3.1. 初始化模型参数\n",
        "下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Bq_p_b712Sc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size\n",
        "def get_params():\n",
        "  def _one(shape):\n",
        "    return tf.Variable(tf.random.normal(stddev=0.01,shape=shape))\n",
        "\n",
        "  def _three():\n",
        "    return (_one((num_inputs,num_hiddens)),\n",
        "        _one((num_hiddens,num_hiddens)),\n",
        "        tf.Variable(tf.zeros(num_hiddens)))\n",
        "    \n",
        "  W_xi,W_hi,b_i=_three() #输入门参数\n",
        "  W_xf,W_hf,b_f=_three() #遗忘门参数\n",
        "  W_xo,W_ho,b_o=_three() #输出门参数\n",
        "  W_xc,W_hc,b_c=_three() #候选记忆细胞参数\n",
        "\n",
        "  #输出层参数\n",
        "  W_hq=_one((num_hiddens,num_outputs))\n",
        "  b_q=tf.Variable(tf.zeros(num_outputs))\n",
        "\n",
        "  params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n",
        "              b_c, W_hq, b_q]\n",
        "  return params"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_odoyXwL4tqx",
        "colab_type": "text"
      },
      "source": [
        "###6.8.4. 定义模型\n",
        "在初始化函数中，长短期记忆的隐藏状态需要返回额外的形状为(批量大小, 隐藏单元个数)的值为0的记忆细胞。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OPLkNl-Cdiqd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def init_lstm_state(batch_size,num_hiddens):\n",
        "  return (tf.zeros(shape=(batch_size,num_hiddens)),tf.zeros(shape=(batch_size,num_hiddens)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n-p5gBTg4xeB",
        "colab_type": "text"
      },
      "source": [
        "下面根据长短期记忆的计算表达式定义模型。需要注意的是，只有隐藏状态会传递到输出层，而记忆细胞不参与输出层的计算。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nnii4aLod4gE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def lstm(inputs,state,params):\n",
        "  W_xi,W_hi,b_i,W_xf,W_hf,b_f,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q=params\n",
        "  (H,C)=state\n",
        "  outputs=[]\n",
        "\n",
        "  for X in inputs:\n",
        "    I=tf.sigmoid(tf.matmul(X,W_xi)+tf.matmul(H,W_hi)+b_i)\n",
        "    F=tf.sigmoid(tf.matmul(X,W_xf)+tf.matmul(H,W_hf)+b_f)\n",
        "    O=tf.sigmoid(tf.matmul(X,W_xo)+tf.matmul(H,W_ho)+b_o)\n",
        "    C_tilda=tf.tanh(tf.matmul(X,W_xc)+tf.matmul(H,W_hc)+b_c)\n",
        "    C=F*C+I*C_tilda\n",
        "    H=O*tf.tanh(C)\n",
        "    Y=tf.matmul(H,W_hq)+b_q\n",
        "    outputs.append(Y)\n",
        "  return outputs,(H,C)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MLm9h1k8fvfu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2\n",
        "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sb_OVwyc461g",
        "colab_type": "text"
      },
      "source": [
        "####6.8.4.1. 训练模型并创作歌词\n",
        "同上一节一样，我们在训练模型时只使用相邻采样。设置好超参数后，我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R_8YW-6ZfyFe",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def to_onehot(X,size):\n",
        "  return [tf.one_hot(x,size) for x in tf.transpose(X)]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QYaAwobuf1Kt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def predict_rnn(prefix,num_chars,rnn,params,init_rnn_state,num_hiddens,vocab_size,idx_to_char,char_to_idx):\n",
        "  state=init_rnn_state(1,num_hiddens)\n",
        "  output=[char_to_idx[prefix[0]]]\n",
        "  for t in range(num_chars+len(prefix)-1):\n",
        "    #将上一时间步的输出作为当前时间步的输入\n",
        "    X=to_onehot(tf.reshape(tf.constant([output[-1]]),shape=(1,1)),vocab_size)\n",
        "    # print(X[0].shape)\n",
        "    #计算输出和更新隐藏状态\n",
        "    (Y,state)=rnn(X,state,params)\n",
        "    #下一个时间步的输入是prefix里的字符或者当前的最佳预测字符\n",
        "    if t<len(prefix)-1:\n",
        "      output.append(char_to_idx[prefix[t+1]])\n",
        "    else:\n",
        "      output.append(tf.argmax(Y[0],axis=1).numpy()[0])\n",
        "  return ''.join([idx_to_char[i] for i in output])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ElIxPtT6f2zw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def data_iter_random(corpus_indices,batch_size,num_steps):\n",
        "  #减1是因为输出的索引是相应输入的索引加1\n",
        "  num_examples=(len(corpus_indices)-1)//num_steps\n",
        "  epoch_size=num_examples//batch_size\n",
        "  example_indices=list(range(num_examples))\n",
        "  random.shuffle(example_indices)\n",
        "\n",
        "  #返回从pos开始的长为num_steps的序列\n",
        "  def _data(pos):\n",
        "    return corpus_indices[pos:pos+num_steps]\n",
        "\n",
        "  for i in range(epoch_size):\n",
        "    #每次读取batch_size个随机样本\n",
        "    i=i*batch_size\n",
        "    batch_indices=example_indices[i:i+batch_size]\n",
        "    X=[_data(j*num_steps) for j in batch_indices]\n",
        "    Y=[_data(j*num_steps+1) for j in batch_indices]\n",
        "    yield tf.constant(X),tf.constant(Y)\n",
        "\n",
        "def data_iter_consecutive(corpus_indices,batch_size,num_steps):\n",
        "  corpus_indices=tf.constant(corpus_indices)\n",
        "  data_len=len(corpus_indices)\n",
        "  batch_len=data_len//batch_size\n",
        "  indices=tf.reshape(corpus_indices[0:batch_size*batch_len],shape=(batch_size,batch_len))\n",
        "  epoch_size=(batch_len-1) // num_steps\n",
        "  for i in range(epoch_size):\n",
        "    i=i*num_steps\n",
        "    X=indices[:,i:i+num_steps]\n",
        "    Y=indices[:,i+1:i+num_steps+1]\n",
        "    yield X,Y    "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sIIQqSoyf43g",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def sgd(params,l,t,lr,batch_size,theta):\n",
        "  norm=tf.constant([0],dtype=tf.float32)\n",
        "  for param in params:\n",
        "    dl_dp=t.gradient(l,param)\n",
        "    if dl_dp is None:\n",
        "      print(param,dl_dp)\n",
        "    norm+=tf.reduce_sum((dl_dp**2))\n",
        "  norm=tf.sqrt(norm).numpy()\n",
        "  if norm>theta:\n",
        "    for param in params:\n",
        "      dl_dp=t.gradient(l,param) #求梯度\n",
        "      dl_dp=tf.assign(tf.Variable(dl_dp),dl_dp*theta/norm) #裁剪梯度\n",
        "      param.assign_sub(lr*dl_dp/batch_size) #更新梯度"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MBCv8G_rf75P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_and_predict_rnn(rnn,get_params,init_rnn_state,num_hiddens,vocab_size,corpus_indices,idx_to_char,char_to_idx,is_random_iter,num_epochs,\n",
        "                          num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes):\n",
        "  if is_random_iter:\n",
        "    data_iter_fn=data_iter_random\n",
        "  else:\n",
        "    data_iter_fn=data_iter_consecutive\n",
        "\n",
        "  params=get_params()\n",
        "  loss=losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "  for epoch in range(num_epochs):\n",
        "    if not is_random_iter:#如果使用相邻采样，在epoch开始时初始化隐藏状态\n",
        "      state=init_rnn_state(batch_size,num_hiddens)\n",
        "    l_sum,n,start=0.0,0,time.time()\n",
        "    data_iter=data_iter_fn(corpus_indices,batch_size,num_steps)\n",
        "    for X,Y in data_iter:\n",
        "      if is_random_iter:#如果使用相邻采样，在每个小批量更新前初始化隐藏状态\n",
        "        state=init_rnn_state(batch_size,num_hiddens)\n",
        "      else:#否则需要使用detach函数从\n",
        "        state=tf.stop_gradient(state) #停止计算该张量的梯度\n",
        "      with tf.GradientTape(persistent=True) as t:\n",
        "        t.watch(params)\n",
        "        inputs=to_onehot(X,vocab_size)\n",
        "        # outputs有num_steps个形状为(batch_size, vocab_size)的矩阵\n",
        "        (outputs,state)=rnn(inputs,state,params)\n",
        "        # 拼接之后形状为(num_steps * batch_size, vocab_size)\n",
        "        outputs=tf.concat(values=[*outputs],axis=0)\n",
        "        # Y的形状是(batch_size, num_steps)，转置后再变成长度为\n",
        "        # batch * num_steps 的向量，这样跟输出的行一一对应\n",
        "        y=tf.reshape(tf.transpose(Y),shape=(-1,))\n",
        "        #使用交叉熵损失计算平均分类误差\n",
        "        l=tf.reduce_mean(loss(y,outputs))\n",
        "      sgd(params,l,t,lr,1,clipping_theta) #因为误差已经取过均值了,所以batch_size为1\n",
        "      l_sum+=l.numpy()*y.numpy().shape[0]\n",
        "      n+=y.numpy().shape[0]\n",
        "    if(epoch +1)%10==0:\n",
        "      print('epoch %d, perplexity %f, time %.2f sec' % (\n",
        "                epoch + 1, l_sum / n, time.time() - start))\n",
        "      for prefix in prefixes:\n",
        "          print(' -', predict_rnn(\n",
        "              prefix, pred_len, rnn, params, init_rnn_state,\n",
        "              num_hiddens, vocab_size, idx_to_char, char_to_idx))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PTkIKF-Nf-BA",
        "colab_type": "code",
        "outputId": "7ded9c6f-b508-48a0-d83a-b49570ce3bbf",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 823
        }
      },
      "source": [
        "train_and_predict_rnn(lstm,get_params,init_lstm_state,num_hiddens,vocab_size,corpus_indices,idx_to_char,char_to_idx,False,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 10, perplexity 5.705122, time 21.27 sec\n",
            " - 分开  我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我\n",
            " - 不分开  我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我\n",
            "epoch 20, perplexity 5.638529, time 21.38 sec\n",
            " - 分开 我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我\n",
            " - 不分开 我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我我\n",
            "epoch 30, perplexity 5.541592, time 21.33 sec\n",
            " - 分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我\n",
            " - 不分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我\n",
            "epoch 40, perplexity 5.345972, time 21.29 sec\n",
            " - 分开 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我\n",
            " - 不分开 我不的我 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 \n",
            "epoch 50, perplexity 5.092433, time 21.68 sec\n",
            " - 分开 我不的我 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 \n",
            " - 不分开 我想的我 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 我不的 \n",
            "epoch 60, perplexity 4.790641, time 21.77 sec\n",
            " - 分开 我想你的我 我不不我 我不不觉 我不不了 我不不了 我不不了 我不不了 我不不了 我不不了 我不不\n",
            " - 不分开 我想你你的我 我不不觉 我不不了 我不不了 我不不了 我不不了 我不不了 我不不了 我不不了 我不\n",
            "epoch 70, perplexity 4.483832, time 21.88 sec\n",
            " - 分开 我想你的爱我 你想我 我不我 我不要 我不不觉 我不不觉 我不不觉 我不不觉 我不了我 我不了 我\n",
            " - 不分开 你想我 你我的我 我不要 我不要 我不不觉 我不不觉 我不不觉 我不不觉 我不了我 我不了 我不么\n",
            "epoch 80, perplexity 4.169854, time 21.42 sec\n",
            " - 分开 我想你你的我 我不要这你 我不要这我 我不要这我 我不要这 我不要 我不了这我 我不要觉 我不了我\n",
            " - 不分开 我想你你的我 我不要这样 我不要这我 我不要这我 我不要这 我不要 我不了这我 我不要觉 我不了我\n",
            "epoch 90, perplexity 3.820179, time 21.27 sec\n",
            " - 分开 我想你你的你 我有你 你不我 我不要这生活 我不不觉 我不要这生活 我不不觉 我不要这生活 我不不\n",
            " - 不分开 我想你你的你 我想要你的你 我不要你 我不要 我不不觉 我不不觉 我不不觉 我不不觉 我不不觉 我\n",
            "epoch 100, perplexity 3.474981, time 21.29 sec\n",
            " - 分开 我想你你的你笑 一散 我不了我 我不要你的你 快爱在我不多 你知后觉 我想了这生活 后知后觉 我想\n",
            " - 不分开 我想你你的我 不要 你不的我有多 一话我 你给我 我想要这样 我不不觉 你不了觉开我 不知不觉 我\n",
            "epoch 110, perplexity 3.119084, time 21.16 sec\n",
            " - 分开 我想你你的微笑 一少在 我想我的你 有有  我给了了我我有多 你不多 你给我 我不眼 我不再 我不\n",
            " - 不分开 我想你你的微笑 就想你 你爱我 我想要这样 我不要这样活 不知后觉 我该了好节奏 后知后觉 我该好\n",
            "epoch 120, perplexity 2.740406, time 20.83 sec\n",
            " - 分开 我想你你已经 我不要你 你不我的爱笑 就情后觉 我想要这生活 我知好觉 我该好这生活 我知好好生活\n",
            " - 不分开 我想你你已经我每天 想想你 别你我抬起看 我不要你 我不要这生活 我知好觉 我不好这生活 我知好好\n",
            "epoch 130, perplexity 2.362516, time 20.89 sec\n",
            " - 分开 我想你你已很很 我想想你的微笑 想想你 你想我 我想再这样牵着我 甩发抖不不不痛 我想要你的微笑 \n",
            " - 不分开 我想你 你已我 我不再 是你怎么我说么 我说是你 是我我想想多 你你是对是我不知 别发抖 快给我抬\n",
            "epoch 140, perplexity 2.085677, time 20.54 sec\n",
            " - 分开 我想带你二堡 我想要你的微笑 就想要 说你眼睛看着我 别发抖 快给我抬起头 有话去对医药箱说 别怪\n",
            " - 不分开 我已你 你怎么 一手不好 我的再空 我不懂痛着我 没你了人的手 快你在我 你怪了 我怎么 停么的停\n",
            "epoch 150, perplexity 1.701517, time 20.69 sec\n",
            " - 分开 你已你 你是我 一些些酒 在一盘  印地的美片 老唱盘 温皮了 恨满了 告诉安的停留 还真么 瞎什\n",
            " - 不分开 我已你 你不么 一手怎么 说你的没 说的梦空 是你的梦 你不放痛 我不懂 痛不了 是不中 停不中 \n",
            "epoch 160, perplexity 1.439472, time 21.30 sec\n",
            " - 分开 你已你 是什么 一九 在九在 在一风 在诉风 装一风 旧一风 在一风 在一风 在一风 不变 装变风\n",
            " - 不分开 我已经不个奏 后知后觉 满脸了双落棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GYiXIaQtgV6N",
        "colab_type": "text"
      },
      "source": [
        "### 6.8.5. 简洁实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_4z6bIPHap48",
        "colab_type": "code",
        "outputId": "0f2f65e5-c2cf-4fb2-9ed9-3a84249124a6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        }
      },
      "source": [
        "BATCH_SIZE=64\n",
        "BUFFER_SIZE=1000\n",
        "seq_length=100\n",
        "def make_dataset():\n",
        "  examples_per_epoch=len(corpus_indices) //seq_length\n",
        "  char_dataset=Dataset.from_tensor_slices(np.array(corpus_indices))\n",
        "  sequences=char_dataset.batch(seq_length+1,drop_remainder=True)\n",
        "  def split_input_target(chunk):\n",
        "    input_text=chunk[:-1]\n",
        "    target_text=chunk[1:]\n",
        "    return input_text,target_text\n",
        "  dataset=sequences.map(split_input_target)\n",
        "  setps_per_epoch=examples_per_epoch // BATCH_SIZE\n",
        "  dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE,drop_remainder=True)\n",
        "  return dataset,setps_per_epoch\n",
        "dataset,setps_per_epoch=make_dataset()\n",
        "for x in dataset:\n",
        "  print(x)\n",
        "  break"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(<tf.Tensor: id=131, shape=(64, 100), dtype=int64, numpy=\n",
            "array([[ 914,  971,  896, ...,  328,  153,  713],\n",
            "       [1016,  665,  608, ...,  914,  552,  502],\n",
            "       [ 971,  608,  814, ...,  292,  730,  914],\n",
            "       ...,\n",
            "       [ 730,  914,  446, ...,    9,  601,   70],\n",
            "       [ 238,  178,  879, ...,    5,   32,  469],\n",
            "       [ 751,  751,  688, ...,  914,  540,  963]])>, <tf.Tensor: id=132, shape=(64, 100), dtype=int64, numpy=\n",
            "array([[971, 896, 324, ..., 153, 713, 914],\n",
            "       [665, 608, 715, ..., 552, 502, 929],\n",
            "       [608, 814, 375, ..., 730, 914, 214],\n",
            "       ...,\n",
            "       [914, 446, 768, ..., 601,  70, 914],\n",
            "       [178, 879, 730, ...,  32, 469,  44],\n",
            "       [751, 688, 563, ..., 540, 963, 502]])>)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ERm2ikJybHBO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_model(num_hiddens=256,embedding_dim=256,epochs=500):\n",
        "  net=keras.Sequential()\n",
        "  net.add(keras.layers.Embedding(input_dim=vocab_size,output_dim=vocab_size,batch_input_shape=(BATCH_SIZE,seq_length)))\n",
        "  net.add(keras.layers.LSTM(num_hiddens,unroll=True,return_sequences=True,stateful=True))\n",
        "  net.add(keras.layers.Dense(vocab_size,activation='softmax'))\n",
        "  net.compile(optimizer=keras.optimizers.Adam(),loss=losses.SparseCategoricalCrossentropy(),metrics=['acc'])\n",
        "  net.fit_generator(dataset.repeat(),steps_per_epoch=setps_per_epoch,epochs=epochs)\n",
        "  return net"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BfyHT45kb1qD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "net=train_model()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZqqS5Zfyb_RF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def generate_text(source_net,pred_len=50,prefix='分开'):\n",
        "  #因为训练的网络是stateful，在keras中使用其预测时输入的向量shape必须跟训练时输入的向量shape一致,\n",
        "  #但是这里我们我们生成文本只需要输入几个前缀文字，因此重新定义一个新模型，并修改输向量的shape，然后使用原有模型的权重\n",
        "  num_hiddens=256\n",
        "  new_net=keras.Sequential()\n",
        "  new_net.add(keras.layers.Embedding(input_dim=vocab_size,output_dim=vocab_size,batch_input_shape=(1,1)))\n",
        "  new_net.add(keras.layers.LSTM(num_hiddens,unroll=True,return_sequences=True,stateful=True))\n",
        "  new_net.add(keras.layers.Dense(vocab_size,activation='softmax'))\n",
        "  new_net.set_weights(source_net.get_weights())\n",
        "  new_net.compile(optimizer=keras.optimizers.Adam(),loss=losses.SparseCategoricalCrossentropy(),metrics=['acc'])\n",
        "  text_generated=prefix\n",
        "  for i in range(pred_len):\n",
        "    id=char_to_idx[text_generated[-1]]\n",
        "    char=idx_to_char[tf.argmax(new_net.predict(tf.constant(value=[id]))[0],axis=-1).numpy()[0]]\n",
        "    text_generated+=char\n",
        "  return text_generated"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z95aNXoFfQnI",
        "colab_type": "code",
        "outputId": "e3776213-e009-4c57-a78c-3ab9adf4e8b2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "print(generate_text(net))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "分开不 我不能再想 我不能 爱情走的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不要再想 我不 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5w4Y-EgkfU6Y",
        "colab_type": "code",
        "outputId": "c230c399-62e1-4432-fb36-7793a8c3270f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "print(generate_text(net,pred_len=100,prefix='不分开'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "不分开不 我不能再想 我不能 爱情走的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不要再想 我不 我不 我不要再想 我不要再想 我不 我不 我不要再想 我不 我不能再想 我不能再想 我不 我不能 爱\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DJcLxZuPf_e8",
        "colab_type": "code",
        "outputId": "7c5243cf-37c2-4abc-eafc-4833b9b71e70",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "print(generate_text(net,pred_len=100,prefix='你发如'))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "你发如果我遇见你是一场悲剧 我想就这样牵着你的手不放开 爱能不能够永远单纯没有悲哀 我 想带你骑单车 我 想和你看棒球 想这样没担忧 唱着歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qEYVJNrNgJab",
        "colab_type": "text"
      },
      "source": [
        "###6.8.6. 小结\n",
        "* 长短期记忆的隐藏层输出包括隐藏状态和记忆细胞。只有隐藏状态会传递到输出层。\n",
        "* 长短期记忆的输入门、遗忘门和输出门可以控制信息的流动。\n",
        "* 长短期记忆可以应对循环神经网络中的梯度衰减问题，并更好地捕捉时间序列中时间步距离较大的依赖关系。"
      ]
    }
  ]
}
